"""
    Software ISP pipelines for RAW images (rawpy-based and hand-written)

    author: wxz
    date: 2021-12-28
    github: https://github.com/xinzwang

    ** Software ISP hyperparameters
        --  version: 1.0.0

              type: numpy.ndarray
              shape: (20,)

              black level correction   params: 2  p[0:2]    [black_level, white_level]
              demosaic                 params: 0
              white balance            params: 3  p[2:5]    [gain_1, gain_2, gain_3]
              color correction         params: 9  p[5:14]   [c_00, c_01, c_02, c_10, c_11, c_12, c_20, c_21, c_22]  (3x3 matrix, row-major)
              gamma correction         params: 2  p[14:16]  [mul, gamma]
              tone curve               params: 4  p[16:20]  [mul, kernel_w, kernel_h, sigma]

"""
import os
import cv2
import glob
import re
import json
import numpy as np
import rawpy
from threading import Lock

from isp.blocks.black_level_correction import black_level_correction
from isp.blocks.color_correction import color_correction
from isp.blocks.demosaik import demosaik
from isp.blocks.gamma_correction import gamma_correction
from isp.blocks.tone_curve import tone_curve
from isp.blocks.white_balance import channel_gain_white_balance


class ISP_base(object):
    """ ISP base class.

    Subclasses override ``set_params`` (store a parameter vector) and
    ``forward`` (process a RAW input), and set ``self.dim_params`` to the
    length of their parameter vector.
    """

    def set_params(self, p):
        """Store the parameter vector ``p``. No-op in the base class."""
        pass

    def forward(self, raw):
        """Process ``raw`` and return the output image. Returns None here."""
        pass

    def __len__(self):
        """Return the number of pipeline parameters (``self.dim_params``).

        Must return an int: returning None makes both ``len(isp)`` and
        truth-testing (``if isp:``) raise TypeError. A bare base instance,
        which has no ``dim_params``, reports 0.
        """
        return int(getattr(self, 'dim_params', 0))


class ISP_rawpy(ISP_base):
    """ ISP pipeline that delegates all processing to ``rawpy.postprocess``.

    Parameter vector layout (6 entries, defaults in ``init_params``):
        p[0]  auto white balance switch: enabled when p[0] > 0.5
              (any value >= 1 enables it)
        p[1]  first element of the ``gamma`` pair passed to rawpy
        p[2]  second element of the ``gamma`` pair passed to rawpy
        p[3]  linear exposure shift, clamped to [0.25, 8]
        p[4]  highlight preservation when brightening, clamped to [0, 1]
        p[5]  ``user_sat``, forwarded to rawpy unchanged
    """

    # Default parameter vector; __init__ sets the matching individual fields.
    _DEFAULT_PARAMS = (0, 2.222, 4.5, 1.0, 0.5, 0.0)

    def __init__(self):
        self.dim_params = 6
        self.use_auto_wb = False               # (0, 1)     auto white balance
        self.gamma = (2.222, 4.5)              # (R, R)     gamma correction
        self.exp_shift = 1                     # 0.25-8.0   linear-scale exposure shift
        self.exp_preserve_highlights = 0.5     # 0.0-1.0    keep highlights when brightening via exp_shift
        self.user_sat = 0.0                    # R          saturation adjustment
        # save_count / lock are not used inside this class; presumably they
        # are used by callers (confirm before removing).
        self.save_count = 0
        self.lock = Lock()
        self.p = np.array(self._DEFAULT_PARAMS)

    def init_params(self):
        """Return a fresh copy of the default 6-entry parameter vector."""
        return np.array(self._DEFAULT_PARAMS)

    @staticmethod
    def _clip(value, low, high):
        """Clamp ``value`` into [low, high].

        The upper bound is tested first, then the lower bound. If neither
        comparison holds (e.g. NaN), the result is ``high``.
        """
        if not value < high:
            return high
        if value > low:
            return value
        return low

    def set_params(self, p):
        """Store ``p`` and derive the rawpy settings from it."""
        self.p = p
        wb_flag = p[0]
        if not wb_flag < 1:
            self.use_auto_wb = True
        elif wb_flag > 0:
            self.use_auto_wb = wb_flag > 0.5
        else:
            self.use_auto_wb = False
        self.gamma = (p[1], p[2])
        self.exp_shift = self._clip(p[3], 0.25, 8)
        self.exp_preserve_highlights = self._clip(p[4], 0, 1)
        self.user_sat = p[5]

    def get_params(self):
        """Return a human-readable dump of the raw vector and derived settings."""
        fields = (
            ('use_auto_wb', self.use_auto_wb),
            ('gamma', self.gamma),
            ('exp_shift', self.exp_shift),
            ('exp_preserve_highlights', self.exp_preserve_highlights),
            ('user_sat', self.user_sat),
        )
        header = f'params:{self.p}\n\n'
        return header + ''.join(f'{name}:{value}\n' for name, value in fields)

    def forward(self, raw):
        """Run ``raw.postprocess`` with the current settings and return its image."""
        options = dict(
            use_auto_wb=self.use_auto_wb,
            gamma=self.gamma,
            exp_shift=self.exp_shift,
            exp_preserve_highlights=self.exp_preserve_highlights,
            user_sat=self.user_sat,
        )
        return raw.postprocess(**options)


class ISP(ISP_base):
    """ Hand-written software ISP pipeline driven by a 20-entry parameter vector.

    ``forward`` applies these stages in order (parameter layout version 1.0.0):

        p[0:2]    black level correction   black_level, white_level
        (none)    demosaic
        p[2:5]    white balance            three channel gains
        p[5:14]   color correction         3x3 matrix, row-major
        p[14:16]  gamma correction         mul, gamma
        p[16:20]  tone curve               mul, kernel (w, h), sigma
    """

    def __init__(self, save_fig=False, fig_path='isp.jpg'):
        """
        Args:
            save_fig: when truthy, ``forward`` draws each stage's output into a
                2x3 grid and saves it to ``fig_path`` (requires matplotlib).
            fig_path: output file for the stage figure.
        """
        self.dim_params = 20
        self.save_fig = save_fig
        self.fig_path = fig_path
        # Start from the default parameters so forward() works even before
        # set_params() has been called.
        self.p = self.init_params()

    def forward(self, raw):
        """Run every pipeline stage on one RAW frame.

        Args:
            raw: a numpy array of RAW sensor values, or any object exposing
                ``raw_image_visible`` (e.g. a rawpy object from ``ISP.imread``).

        Returns:
            The tone-curve output multiplied by 255.

        Raises:
            AssertionError: if ``raw`` is None or ``self.p`` does not have
                shape (20,). These are raised explicitly rather than via
                ``assert`` so the checks survive ``python -O``.
        """
        # Check for None before the conversion below, which would otherwise
        # fail with an unrelated AttributeError.
        if raw is None:
            raise AssertionError('[ERROR] input raw data is None')
        raw = raw if isinstance(raw, np.ndarray) else raw.raw_image_visible

        # version 1.0.0     p->params (layout in the class docstring)
        p = np.asarray(self.p)
        if p.shape != (20,):
            raise AssertionError(f'[ERROR] params shape error. shape: {p.shape}')

        # (subplot title, stage function) in pipeline order.
        stages = (
            ('black_level_correction', lambda x: black_level_correction(x, p[0], p[1])),
            ('demosaik', demosaik),
            ('white balance', lambda x: channel_gain_white_balance(x, p[2:5])),
            ('color correction', lambda x: color_correction(x, p[5:14].reshape(3, 3))),
            ('gamma', lambda x: gamma_correction(x, p[14], p[15])),
            ('tone curve', lambda x: tone_curve(x, p[16], p[17:19], p[19])),
        )

        fig = plt = None
        if self.save_fig:
            # Imported lazily so matplotlib is only needed when figures are requested.
            from matplotlib import pyplot as plt
            # A fresh figure per call. Drawing into pyplot's current figure made
            # repeated calls overlay each other and never released the figures.
            fig = plt.figure()
        try:
            img = raw
            for index, (title, stage) in enumerate(stages, start=1):
                img = stage(img)
                if fig is not None:
                    ax = fig.add_subplot(2, 3, index)
                    ax.set_title(title)
                    ax.imshow(img)
            if fig is not None:
                fig.savefig(self.fig_path, dpi=300)
        finally:
            if fig is not None:
                plt.close(fig)
        return img * 255

    def set_params(self, p):
        """Store the 20-entry parameter vector used by ``forward``."""
        self.p = p

    @staticmethod
    def imread(raw_path):
        """Read one or more RAW images with ``rawpy.imread``.

        Args:
            raw_path: a single path, or a list/tuple of paths.

        Returns:
            A list with one rawpy object per path, in input order. The caller
            owns these objects and should ``close()`` them when done.
        """
        paths = raw_path if isinstance(raw_path, (list, tuple)) else [raw_path]
        return [rawpy.imread(path) for path in paths]

    @staticmethod
    def init_params(dataset='LOD'):
        """Return a default 20-entry parameter vector.

        Args:
            dataset: which preset to use: 'LOD' (default, the previous
                hard-coded choice) or 'SID'. The presets are presumably tuned
                for the RAW datasets of the same names (confirm).

        Raises:
            ValueError: for an unknown ``dataset`` name.
        """
        presets = {
            'SID': [512, 16383, 1.9296875, 1.0, 2.26171875, .9020, -.2890, -.0715, -.4535, 1.2436, .2348, -.0934,
                    .1919, .7086, 80.0, 2.2, 0.1, 3.0, 3.0, 1.0],
            'LOD': [2047, 14448, 2.074697256088257, 0.9324925541877747, 1.1760492324829102, .9020, -.2890, -.0715,
                    -.4535, 1.2436, .2348, -.0934, .1919, .7086, 80.0, 2.2, 0.1, 3.0, 3.0, 1.0],
        }
        try:
            return np.array(presets[dataset])
        except KeyError:
            raise ValueError(f'unknown dataset {dataset!r}; expected one of {sorted(presets)}') from None

    @staticmethod
    def save_params(p=None, path="./runs/isp", suffix=""):
        """Save an ISP parameter vector to a JSON file.

        Args:
            p: 20-entry parameter vector (array-like). It is required; the None
                default is kept only for signature compatibility.
            path: either a target ``*.json`` file, or a directory. For a
                directory, the file is named ``params_<suffix><n>.json``, where
                ``n`` is one more than the largest trailing number among the
                ``*.json`` files already in it (0 if there are none).
            suffix: text inserted before the number in generated file names.

        Raises:
            TypeError: if ``p`` is None.
        """
        if p is None:
            raise TypeError('save_params() requires the parameter vector p')
        # numpy scalars such as float32/int64 are not JSON serializable, so
        # convert everything to plain Python floats.
        p = np.asarray(p, dtype=float)
        a = {
            "black_level_correction": {
                "black_level": float(p[0]),
                "white_level": float(p[1]),
            },
            "white_balance": p[2:5].tolist(),
            "color_correction": p[5:14].tolist(),
            "gamma_correction": {
                "mul": float(p[14]),
                "gamma": float(p[15]),
            },
            "tone_curve": {
                "mul": float(p[16]),
                "kernel": p[17:19].tolist(),
                "sigma": float(p[19]),
            },
        }
        path = os.fspath(path).replace('/', os.sep)
        if not path.endswith('.json'):
            # Join with the platform separator. A hard-coded '\\' never matched
            # on POSIX, so every save overwrote params_0.json.
            numbers = []
            for name in glob.glob(os.path.join(glob.escape(path), '*.json')):
                # Only the file's own trailing digits identify its run number;
                # digits elsewhere in the path must not be picked up.
                stem = os.path.splitext(os.path.basename(name))[0]
                match = re.search(r'[0-9]+$', stem)
                if match:
                    numbers.append(int(match.group()))
            num = max(numbers) + 1 if numbers else 0
            path = os.path.join(path, 'params_' + suffix + str(num) + '.json')
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        with open(path, 'w') as f:
            json.dump(a, f)

    @staticmethod
    def load_params(path):
        """Load a parameter vector from a JSON file in the ``save_params`` format.

        Returns:
            A 1-D numpy array, laid out as described in the class docstring.
        """
        with open(path) as f:
            a = json.load(f)
        p = [
            a["black_level_correction"]['black_level'],
            a["black_level_correction"]['white_level'],
            a["white_balance"],
            a["color_correction"],
            a["gamma_correction"]["mul"],
            a["gamma_correction"]["gamma"],
            a["tone_curve"]["mul"],
            a["tone_curve"]["kernel"],
            a["tone_curve"]["sigma"]
        ]
        return np.array(flatten(p))


# ---------------- utility functions ----------------
# Flatten the list to 1-D
flatten = lambda x: [y for l in x for y in flatten(l)] if type(x) is list else [x]

if __name__ == '__main__':
    # Smoke check: build the hand-written ISP and fetch its default parameters.
    isp = ISP()
    params = isp.init_params()
